# Ratings of each hypothesis in long format. Rows without a rating value are
# dropped up front so later steps only see observed ratings.
rating_results_long <- read_csv(
  file = "data/ml_hypothesis_generation/rating_results_long.csv"
) %>%
  filter(!is.na(rating_val))

# Lookup table of hypotheses (hypothesis_id, hypothesis_cleaned, ...).
hypotheses_tbl <- read_csv(
  file = "data/ml_hypothesis_generation/hypotheses_tbl.csv"
)

# Fix the RNG so the random steps later in the script are reproducible.
set.seed(12345)

rerun_embeddings <- FALSE

if (rerun_embeddings) {

  # create_embedding() authenticates with the OPENAI_API_KEY environment
  # variable. Set it outside this script (e.g. in ~/.Renviron) rather than
  # pasting a secret into source code, and fail fast when it is missing.
  if (!nzchar(Sys.getenv("OPENAI_API_KEY"))) {
    stop(
      "OPENAI_API_KEY is not set; define it in your environment ",
      "(e.g. ~/.Renviron) before rerunning embeddings.",
      call. = FALSE
    )
  }

  # Embed `texts` with text-embedding-ada-002, sending at most `batch_size`
  # inputs per request (the embeddings endpoint limits inputs per request).
  # Returns one ";"-joined string per input text, in input order. A batch whose
  # request fails yields NA for its texts (with a warning) instead of wiping
  # out every embedding in the run.
  get_embeddings <- function(texts, batch_size = 500) {
    if (length(texts) == 0) {
      return(character(0))
    }

    embed_batch <- function(batch) {
      tryCatch({
        response <- create_embedding(
          model = "text-embedding-ada-002",
          input = batch
        )
        # Serialize each embedding vector as a ";"-separated string
        map_chr(response$data$embedding, function(emb) {
          paste(emb, collapse = ";")
        })
      }, error = function(e) {
        warning(sprintf("Error getting embeddings: %s", e$message))
        rep(NA_character_, length(batch))
      })
    }

    # split() on consecutive integer batch ids keeps the original text order
    batches <- split(texts, ceiling(seq_along(texts) / batch_size))
    unlist(map(batches, embed_batch), use.names = FALSE)
  }

  # Add embeddings (string form and parsed numeric list form) to hypotheses_tbl
  hypotheses_with_embeddings <- hypotheses_tbl %>%
    mutate(
      embedding = get_embeddings(hypothesis_cleaned),
      embedding_timestamp = Sys.time()
    ) %>%
    mutate(embedding_list = str_split(embedding, ";")) %>%
    mutate(embedding_list = map(embedding_list, as.numeric))

  # Save the results into RData
  save(hypotheses_with_embeddings, file = "data/ml_hypothesis_generation/hypotheses_with_embeddings.RData")

}

# Load the cached embeddings (regenerated above only when rerun_embeddings is TRUE)
load("data/ml_hypothesis_generation/hypotheses_with_embeddings.RData")


# Coerce a matrix (e.g. one holding numeric strings) to a numeric matrix with
# the same number of rows and columns. Dimnames are not carried over.
# Despite the dotted name this is a plain helper: S3 methods for as.numeric()
# must be written for as.double(), so it is never dispatched to automatically.
as.numeric.matrix <- function(x) {
  n_row <- nrow(x)
  n_col <- ncol(x)
  values <- as.numeric(x)
  matrix(values, nrow = n_row, ncol = n_col)
}

use_embeddings <- TRUE

if (use_embeddings) {

  library(Rtsne)

  # Deduplicate hypotheses once and reuse this tibble for the embedding matrix,
  # the t-SNE coordinates and the k-means labels, so rows line up by
  # construction rather than by repeating the same dedup call.
  hypotheses_unique <- hypotheses_with_embeddings %>%
    dups_report(hypothesis_cleaned) %>%
    dups_drop(hypothesis_cleaned)

  # Parse the ";"-separated embedding strings into numeric vectors and stack
  # them into a matrix with one row per hypothesis. The dimension comes from
  # the data instead of being hard-coded (text-embedding-ada-002 gives 1536).
  embedding_vectors <- hypotheses_unique$embedding %>%
    str_split(pattern = ";") %>%
    map(as.numeric)

  if (anyNA(embedding_vectors, recursive = TRUE) ||
      n_distinct(lengths(embedding_vectors)) != 1) {
    stop(
      "Embeddings are missing or have inconsistent dimensions; ",
      "regenerate them with rerun_embeddings <- TRUE.",
      call. = FALSE
    )
  }
  embeddings <- do.call(rbind, embedding_vectors)

  # Step 1: Project embeddings to 2D using t-SNE.
  # The columns keep their historical pca_* names although this is t-SNE.
  tsne_result <- Rtsne(embeddings, dims = 2, perplexity = 30, verbose = TRUE, max_iter = 500)

  hypotheses_with_pca <- hypotheses_unique %>%
    mutate(
      pca_1 = tsne_result$Y[, 1],
      pca_2 = tsne_result$Y[, 2]
    )

  # Step 2: Perform k-means clustering.
  # Several random starts and a higher iteration cap make the partition much
  # less dependent on the initial centers than a single start would be.
  n_clusters <- 50  # You can adjust this number
  kmeans_result <- kmeans(embeddings, centers = n_clusters, nstart = 10, iter.max = 100)

  hypotheses_with_clusters <- hypotheses_with_pca %>%
    mutate(cluster = kmeans_result$cluster)

  # Step 3: Visualize the results. Inside an if-block a ggplot is not
  # auto-printed, so print it explicitly.
  print(
    ggplot(hypotheses_with_clusters, aes(x = pca_1, y = pca_2, color = factor(cluster))) +
      geom_point(alpha = 0.6) +
      theme_minimal() +
      labs(title = "Hypothesis Embeddings Visualization with Cluster Labels",
           x = "t-SNE feature 1",
           y = "t-SNE feature 2")
  )

  # Get representative hypotheses from each cluster.
  # Cluster centers as a list column: one numeric vector per cluster.
  cluster_means <- tibble(
    cluster = 1:nrow(kmeans_result$centers),
    cluster_mean_embedding = kmeans_result$centers %>%
      apply(1, function(x) x, simplify = FALSE)
  )

  # Row-wise cosine similarity between two equally sized sets of vectors,
  # given either as matrices or as lists of numeric vectors.
  cosine_similarity_paired <- function(embeddings1, embeddings2) {
    if (is.list(embeddings1)) {
      embeddings1 <- do.call(rbind, embeddings1)
    }
    if (is.list(embeddings2)) {
      embeddings2 <- do.call(rbind, embeddings2)
    }
    dot_products <- rowSums(embeddings1 * embeddings2)
    norm1 <- sqrt(rowSums(embeddings1^2))
    norm2 <- sqrt(rowSums(embeddings2^2))
    similarities <- dot_products / (norm1 * norm2)
    return(similarities)
  }

  # Attach each hypothesis's own vector (already parsed above, same row order)
  # and its cluster center, then score how central each hypothesis is.
  hypotheses_with_clusters <- hypotheses_with_clusters %>%
    mutate(embedding_vector = embedding_vectors) %>%
    left_join(cluster_means, by = "cluster") %>%
    mutate(
      similarity_to_cluster_mean = cosine_similarity_paired(embedding_vector, cluster_mean_embedding)
    )

  # The 3 hypotheses most similar to each cluster's center
  representative_hypotheses <- hypotheses_with_clusters %>%
    group_by(cluster) %>%
    arrange(desc(similarity_to_cluster_mean)) %>%
    slice_head(n = 3) %>%
    ungroup() %>%
    select(cluster, hypothesis_cleaned, similarity_to_cluster_mean)

}


# Quick look at the distribution and structure of the ratings
rating_results_long %>% hist_basic(rating_val, binwidth = 1, boundary = 0)
glimpse(rating_results_long)

# Round-1 choices from full-discussion rounds, reduced to one row per
# (group_id, round).
r1_choices_discussion <- r1_choices %>%
  filter(discussion_full == 1) %>%
  select(
    group_id, phase, round, discuss_type,
    r1_choose_comparator, r1_choose_trans, pair_includes_trans
  ) %>%
  tidylog::distinct() %>%
  dups_report(group_id, round) %>%
  dups_drop(group_id, round)

# Attach the round-1 choice info to every rating; a rating's pair_id is
# matched against the round number.
rating_results_with_r1 <- tidylog::left_join(
  rating_results_long,
  r1_choices_discussion,
  by = c("group_id", "pair_id" = "round")
)


# Optionally: reduce the dimensionality of all the hypotheses by:
# 1. Building a dataset at the group_id x pair_id level with one column
#    (hyp_<hypothesis_id>) per hypothesis holding its rating.
rating_results_super_wide <- rating_results_with_r1 %>%
  select(group_id, pair_id, hypothesis_id, rating_val) %>%
  pivot_wider(
    id_cols = c(group_id, pair_id),
    names_from = hypothesis_id,
    values_from = rating_val,
    names_prefix = "hyp_",
    # If a (group, pair, hypothesis) cell has several ratings, average them
    # (the same aggregation hypothesis_ratings_agg uses) so every hyp_* column
    # stays numeric instead of silently becoming a list-column.
    values_fn = mean
  ) %>%
  # id_cols already makes (group_id, pair_id) unique; kept as a guard.
  dups_drop(group_id, pair_id)

rating_results_super_wide %>%
  count_nas()


# 2. Use factor analysis to see which hypotheses load on which factors.
# The number of factors was chosen from the analysis below (parallel analysis
# suggested 11 factors):
# n_factors_data <- rating_results_super_wide %>%
#   select(-group_id, -pair_id) %>%
#   drop_na() %>%
#   n_factors()
n_factors_manual <- 11

# Loadings of each hypothesis on each factor (complete cases only), sorted by
# the second factor, with the hypothesis text attached and the MR* loading
# columns renamed to fact_*.
hypotheses_loadings <- rating_results_super_wide %>%
  select(-c(group_id, pair_id)) %>%
  drop_na() %>%
  factor_loadings(n_factors = n_factors_manual) %>%
  arrange(desc(MR2)) %>%
  mutate(hypothesis_id = as.numeric(str_remove(var, "hyp_"))) %>%
  left_join(
    select(hypotheses_tbl, hypothesis_id, hypothesis_cleaned),
    by = "hypothesis_id"
  ) %>%
  relocate(hypothesis_cleaned, .before = var) %>%
  rename_with(~ str_replace(.x, "MR", "fact_"), .cols = starts_with("MR"))


# Average rating per hypothesis x group, separately for discussions whose pair
# does / does not include the trans option.
hypothesis_ratings_agg <- rating_results_with_r1 %>%
  group_by(hypothesis_id, group_id, pair_includes_trans) %>%
  summarise(
    avg_rating = mean(rating_val, na.rm = TRUE),
    n_ratings = n()
  ) %>%
  ungroup()

# One row per hypothesis x group, with the non-trans and trans averages side by side.
# NOTE(review): a hypothesis with no ratings in one pair type is imputed as 0,
# which is indistinguishable from an actual rating of 0 - confirm intent.
hypothesis_ratings_wide <- hypothesis_ratings_agg %>%
  pivot_wider(
    id_cols = c(hypothesis_id, group_id),
    names_from = pair_includes_trans,
    names_prefix = "rating_trans_",
    values_from = avg_rating
  ) %>%
  mutate(
    rating_non_trans = replace_na(rating_trans_0, 0),
    rating_trans = replace_na(rating_trans_1, 0)
  ) %>%
  dups_drop(hypothesis_id, group_id)


r2_choices_num_trans <- r2_choices_num %>%
  filter(!is.na(r2_choose_trans))


# Run separate regressions for each hypothesis. Each regresses r2_choose_trans
# on the group's average rating of that hypothesis in trans and non-trans
# discussions, with standard errors clustered by group_id.
hypotheses_list <- unique(hypothesis_ratings_wide$hypothesis_id)
# Key results by the id as a string: [[<numeric id>]] indexes by position,
# which pads the list with NULLs, truncates non-integer ids and fails for 0.
model_results <- vector("list", length(hypotheses_list))
names(model_results) <- as.character(hypotheses_list)

for (hyp_id in hypotheses_list) {
  print(str_glue("Running model for hypothesis {hyp_id}"))

  # Filter data for this hypothesis and merge with round-2 outcomes
  hyp_data <- hypothesis_ratings_wide %>%
    filter(hypothesis_id == hyp_id) %>%
    left_join(r2_choices_num_trans, by = "group_id")

  # Run model
  model <- feols_custom(
    r2_choose_trans ~ rating_trans + rating_non_trans,
    data = hyp_data,
    cluster = "group_id"
  )

  # Store results
  model_results[[as.character(hyp_id)]] <- model %>% tidy_90 %>%
    mutate(hypothesis_id = hyp_id)

  rm(hyp_data, model)
}
# Combine results and join with hypothesis text
hypothesis_effects <- bind_rows(model_results) %>%
  left_join(hypotheses_tbl, by = "hypothesis_id")

# Plot the two rating coefficients for every hypothesis
hypothesis_effects %>%
  filter(term != "(Intercept)") %>%
  ggplot(
    aes(x = hypothesis_cleaned,
        y = estimate,
        ymin = conf.low,
        ymax = conf.high,
        color = term)) +
  geom_pointrange(position = position_dodge(width = 0.5)) +
  geom_hline(yintercept = 0, linetype = "dashed", alpha = 0.5) +
  coord_flip() +
  scale_color_manual(
    values = c("rating_trans" = "blue", "rating_non_trans" = "red"),
    labels = c("rating_trans" = "Trans Rating", "rating_non_trans" = "Non-Trans Rating")
  ) +
  labs(
    title = "Hypothesis Effects on R2 Trans Choice",
    x = "Hypothesis",
    y = "Effect Size",
    color = "Rating Type"
  ) +
  theme_minimal() +
  theme(
    axis.text.y = element_text(size = 8),
    legend.position = "bottom"
  )


# USE FACTOR INDICES INSTEAD
# Build one rating index per factor (from hypotheses_loadings) as an
# alternative to the per-hypothesis regressions.
# Loadings below this value are set to 0 before weighting.
factor_thresh <- 0.3

# Calculate factor indices: one row per (group_id, pair_includes_trans), with
# one standardized fact_* column per factor.
factor_indices <- rating_results_with_r1 %>%
  group_by(group_id) %>% 
  # Group-level average (via mean_na) of r1_choose_trans taken over the rating
  # rows, so each pair counts once per rating it received. Despite the n_
  # prefix this is an average, not a count (assuming mean_na is an NA-dropping mean).
  mutate(n_r1_choose_trans = mean_na(r1_choose_trans)) %>%  # Calculate mean r1_choose_trans per group
  # Join with hypotheses loadings
  left_join(
    hypotheses_loadings %>% 
      select(hypothesis_id, matches("fact_")),
    by = "hypothesis_id"
  ) %>%
  # For each group and pair type, calculate weighted average of ratings
  group_by(group_id, pair_includes_trans) %>%
  summarise(
    n_r1_choose_trans = first(n_r1_choose_trans),  # Keep n_r1_choose_trans in summarise
    across(starts_with("fact_"), function(factor_loading) {
      # factor_loading holds this factor's loading for each rating row in the
      # group (NA for hypotheses without a loading); rating_val is read from
      # the same rows. The result is a loading-weighted mean rating.
      # NOTE(review): `factor_loading < factor_thresh` also zeroes strongly
      # *negative* loadings, so abs() below never sees a negative value. If
      # negative loaders should count (with sign), threshold on
      # abs(factor_loading) instead - confirm intent before changing results.
      # Weighted sum of ratings divided by sum of absolute loadings
      factor_loading <- if_else(factor_loading < factor_thresh, 0, factor_loading)
      weighted_sum <- sum_na(rating_val * factor_loading)
      loading_sum <- sum_na(abs(factor_loading))
      # No hypothesis above the threshold in this group -> index undefined
      ifelse(loading_sum == 0, NA, weighted_sum / loading_sum)
    })
  ) %>%
  ungroup() %>%
  # Normalize each factor to mean 0, sd 1 (pooled over both pair types)
  mutate(
    across(starts_with("fact_"), ~scale(.x)[,1])
  )

# Verify the results: every fact_* column should show mean ~0 and sd 1
factor_indices %>%
  summarise(
    across(starts_with("fact_"), list(
      mean = ~mean_na(.x) %>% round(3),
      sd = ~sd(.x, na.rm = TRUE)
    ))
  ) %>% 
  glimpse

# Reshape factor indices to wide format: one row per group, with columns
# fact_<k>_non_trans (pair_includes_trans == 0) and fact_<k>_trans (== 1).
# Rows whose pair_includes_trans is NA would produce *_trans_NA columns,
# which the renames below leave untouched.
factor_indices_wide <- factor_indices %>%
  pivot_wider(
    id_cols = c(group_id, n_r1_choose_trans),  # Include n_r1_choose_trans in id_cols
    names_from = pair_includes_trans,
    values_from = starts_with("fact_"),
    names_glue = "{.value}_trans_{pair_includes_trans}"
  ) %>%
  # Rename columns to simpler format
  rename_with(
    ~str_replace(.x, "_trans_0", "_non_trans"),
    matches("_trans_0")
  ) %>%
  rename_with(
    ~str_replace(.x, "_trans_1", "_trans"),
    matches("_trans_1")
  )

# Run separate regressions for each factor
factor_names <- str_subset(names(factor_indices), "^fact_")
model_results <- list()
split_models <- TRUE  # Toggle between combined and split models
conditional_on_choice <- FALSE

# Control: the group total of `employer` minus the row's own value
# (leave-one-out within group_id).
r2_choices_num_for_lasso <- r2_choices_num %>%
  group_by(group_id) %>%
  mutate(employer_group_control = (sum_na(employer) - employer)) %>%
  ungroup()

# Loop invariants, built once rather than on every model call: the estimation
# data, the fixed effects and the control part of the formula.
factor_model_data <- factor_indices_wide %>%
  left_join(r2_choices_num_for_lasso, by = "group_id")
factor_model_fixef <- c("stratum_id", "video_type", "delivery_incentive_exp", "comparator_order_in_pair", "phase")

base_formula <- "r2_choose_trans ~ item_diff + employer_group_control" # these are the LASSO controls
# Add conditional control if enabled
if (conditional_on_choice) {
  base_formula <- paste(base_formula, "+ n_r1_choose_trans")
}

# Only the factor-index coefficients are kept. Applying this in both branches
# keeps control coefficients out of the FDR family and the plot downstream.
factor_term_pattern <- "fact_\\d+_trans|fact_\\d+_non_trans"

for (factor_name in factor_names) {
  print(str_glue("Running model for {factor_name}"))

  # Create variable names for trans and non-trans versions
  trans_var <- paste0(factor_name, "_trans")
  non_trans_var <- paste0(factor_name, "_non_trans")

  if (!split_models) {
    # Combined model approach: both indices enter the same regression
    formula <- as.formula(paste(base_formula, "+", trans_var, "+", non_trans_var))
    model <- feols_custom(
      formula,
      data = factor_model_data,
      cluster = "group_id",
      fixef = factor_model_fixef
    )
    results <- model %>%
      tidy_90 %>%
      filter(str_detect(term, factor_term_pattern)) %>%
      mutate(factor = factor_name)
  } else {
    # Split models approach: one regression per index
    formula_trans <- as.formula(paste(base_formula, "+", trans_var))
    formula_non_trans <- as.formula(paste(base_formula, "+", non_trans_var))

    model_trans <- feols_custom(
      formula_trans,
      data = factor_model_data,
      cluster = "group_id",
      fixef = factor_model_fixef
    )

    model_non_trans <- feols_custom(
      formula_non_trans,
      data = factor_model_data,
      cluster = "group_id",
      fixef = factor_model_fixef
    )

    results <- bind_rows(
      model_trans %>% tidy_90 %>% filter(str_detect(term, factor_term_pattern)),
      model_non_trans %>% tidy_90 %>% filter(str_detect(term, factor_term_pattern))
    ) %>% mutate(factor = factor_name)
  }

  # Store results
  model_results[[factor_name]] <- results
}

# Stack the per-factor coefficients, drop any intercept, label each coefficient
# by discussion type (from its _trans / _non_trans suffix), then add
# FDR-adjusted q-values across all remaining coefficients.
factor_effects <- model_results %>%
  bind_rows() %>%
  filter(term != "(Intercept)") %>%
  mutate(
    rating_type = case_when(
      str_detect(term, "_non_trans$") ~ "Non-Trans Rating",
      str_detect(term, "_trans$") ~ "Trans Rating",
      TRUE ~ term
    )
  ) %>%
  ungroup() %>%
  mutate(q.value = q_val(p.value))



# Label each factor with its single highest-|loading| hypothesis (apostrophes
# stripped, wrapped at 70 characters for the axis).
factor_labels <- hypotheses_loadings %>%
  select(hypothesis_cleaned, matches("fact_")) %>%
  mutate(hypothesis_cleaned = str_remove_all(hypothesis_cleaned, fixed("'"))) %>%
  pivot_longer(
    cols = starts_with("fact_"),
    names_to = "factor",
    values_to = "loading"
  ) %>%
  group_by(factor) %>%
  # with_ties = FALSE guarantees exactly one label per factor; tied loadings
  # would otherwise duplicate coefficient rows in the join below.
  slice_max(abs(loading), n = 1, with_ties = FALSE) %>%
  ungroup() %>%
  mutate(
    factor = str_remove(factor, "fact_"),
    hypothesis_label = str_wrap(hypothesis_cleaned, width = 70)
  ) %>%
  glimpse

# Coefficients joined to their factor labels, plus plotting helpers
factor_effects_plot <- factor_effects %>%
  mutate(factor = str_remove(factor, "fact_")) %>%
  tidylog::left_join(factor_labels, by = "factor") %>%
  group_by(factor) %>%
  mutate(
    # Sort key: the factor's trans-rating estimate. first() yields NA instead
    # of erroring when a factor has no trans-rating row.
    sort_value = first(estimate[rating_type == "Trans Rating"]),
    # Add significance stars based on q-values for all coefficients
    stars = case_when(
      q.value < 0.01 ~ "***",
      q.value < 0.05 ~ "**",
      q.value < 0.1 ~ "*",
      TRUE ~ ""
    ),
    # Set hjust based on coefficient direction
    star_hjust = if_else(estimate < 0, 1.5, -0.5),
    # Set vertical position based on rating type
    star_vjust = if_else(rating_type == "Trans Rating", 1, 6) * (1/5)
  ) %>%
  ungroup()

factor_effects_plot %>% glimpse


# Coefficient plot: one row per factor, labelled by its top-loading hypothesis
# and sorted by the trans-rating effect, with q-value significance stars.
# NOTE(review): `size` in element_rect and numeric legend.position /
# legend.*.align are deprecated in recent ggplot2 releases; kept as-is to
# avoid depending on a newer ggplot2 version.
factor_effects_gg <- ggplot(
  factor_effects_plot,
  aes(
    x = reorder(hypothesis_label, sort_value),
    y = estimate,
    ymin = conf.low,
    ymax = conf.high,
    color = rating_type,
    shape = rating_type
  )
) +
  geom_hline(yintercept = 0, linetype = "dashed", alpha = 0.5) +
  geom_pointrange(position = position_dodge(width = 0.5)) +
  geom_text(
    aes(hjust = star_hjust, vjust = star_vjust, label = stars),
    size = 5,
    position = position_dodge(width = 0.5),
    show.legend = FALSE
  ) +
  coord_flip(ylim = c(-0.15, 0.27)) +
  labs(
    x = "Most Representative Hypothesis",
    y = "Association between discussion rating (Z) on\nP(choose trans in post-discussion outcome round)",
    color = "Discussion type",
    shape = "Discussion type"
  ) +
  scale_color_discrete(
    labels = c(
      "Non-Trans Rating" = "Choice doesn't include trans",
      "Trans Rating" = "Choice includes trans"
    )
  ) +
  # Same labels as the color scale so the two legends merge into one
  scale_shape_discrete(
    labels = c(
      "Non-Trans Rating" = "Choice doesn't include trans",
      "Trans Rating" = "Choice includes trans"
    )
  ) +
  theme_classic() +
  theme(
    axis.text.y = element_text(size = 8, lineheight = 1),
    legend.position = c(1, 0.02),
    legend.justification = c(1, 0),
    legend.background = element_rect(fill = "white", color = "grey90", size = 0.5),
    plot.margin = margin(t = 10, r = 10, b = 10, l = 10, unit = "pt"),
    legend.title.align = 1,  # Right-align legend title
    legend.text.align = 1    # Right-align legend text
  )

print(factor_effects_gg)

# Save the stored plot explicitly instead of relying on last_plot().
# Dimensions are 10.8 x 12 in, multiplied by scale = 0.7.
ggsave("outputs/figs/hypothesis_effects_fct.pdf", plot = factor_effects_gg, width = 10.8, height = 12, scale = 0.7)


# Return the single q-value (FDR-adjusted p-value) of the coefficient selected
# by the filter conditions in `...`. Errors if the conditions match zero or
# several rows, instead of passing an empty or multi-valued stat to write_stat().
pull_single_q_value <- function(data, ...) {
  q_values <- data %>%
    filter(...) %>%
    pull(q.value)
  if (length(q_values) != 1) {
    stop(
      sprintf("Expected exactly one matching coefficient, found %d.", length(q_values)),
      call. = FALSE
    )
  }
  q_values
}

# Export the q-value of factor 1's trans-discussion coefficient
# (the output file labels this factor "social equity").
factor_effects_plot %>%
  pull_single_q_value(term == "fact_1_trans") %>%
  write_stat("outputs/stats/pval_hypothesis_effects_fct_social_equity.tex", digits = 3, p_value = TRUE)

# Export the q-value of factor 8's trans-discussion coefficient.
factor_effects_plot %>%
  pull_single_q_value(term == "fact_8_trans") %>%
  write_stat("outputs/stats/pval_hypothesis_effects_fct_8.tex", digits = 3, p_value = TRUE)

# Interactive inspection: trans-discussion coefficients, largest effect first.
factor_effects_plot %>%
  arrange(desc(sort_value)) %>%
  filter(!str_detect(term, "non_trans")) %>%
  relocate(q.value, hypothesis_label) %>%
  view()

# Export the trans-discussion q-value for the factor labelled by the
# language-proficiency hypothesis.
factor_effects_plot %>%
  pull_single_q_value(
    hypothesis_cleaned == "The emphasis on language proficiency as a deciding factor for preference.",
    !str_detect(term, "non_trans")
  ) %>%
  write_stat("outputs/stats/pval_hypothesis_effects_fct_language_proficiency.tex", digits = 2, p_value = TRUE)

# Same, for the factor labelled "The level of specific positive feedback about
# delivery performance."
factor_effects_plot %>%
  pull_single_q_value(
    hypothesis_cleaned == "The level of specific positive feedback about delivery performance.",
    !str_detect(term, "non_trans")
  ) %>%
  write_stat("outputs/stats/pval_hypothesis_effects_fct_delivery_performance.tex", digits = 2, p_value = TRUE)